import os
import torch
import pandas as pd
import numpy as np
import random
from utils_dqa import compute_reference_features, save_images, compute_mmd, approximate_mmd, preprocess
from tqdm import tqdm
from glob import glob
import copy
import os
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
"""
Importing libraries to override Stable Diffusion in transformer library.
"""
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import (
    deprecate,
    is_torch_xla_available)

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return ``(timesteps, num_inference_steps)``.

    At most one of ``timesteps`` / ``sigmas`` may be given. When one is given,
    the scheduler's ``set_timesteps`` must accept a parameter of that name, and
    the returned step count is the length of the resulting schedule. Otherwise
    ``set_timesteps(num_inference_steps, ...)`` is used and
    ``num_inference_steps`` is returned unchanged.

    Raises:
        ValueError: if both custom schedules are given, or the scheduler does
            not support the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler build its own schedule.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule the step count is whatever the scheduler produced.
    schedule = scheduler.timesteps
    return schedule, len(schedule)

# torch_xla is optional: import it (and enable the per-step xm.mark_step()
# call in the pipeline) only when is_torch_xla_available() is truthy.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Blend a guided noise prediction with a copy rescaled to the text prediction's std.

    The per-sample standard deviation is taken over every axis except the
    first. ``noise_cfg`` is rescaled so that its std matches
    ``noise_pred_text`` (this counters overexposure from guidance). The result
    is then mixed back with the un-rescaled ``noise_cfg`` by
    ``guidance_rescale``, which avoids "plain looking" images.
    ``guidance_rescale=0`` returns ``noise_cfg`` unchanged.
    """
    def _per_sample_std(x):
        # std over all non-leading axes, kept for broadcasting back onto x
        return x.std(dim=list(range(1, x.ndim)), keepdim=True)

    rescaled = noise_cfg * (_per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg))
    blended = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
    return blended


device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")  # first GPU if present, else CPU


"""
Preparing reference datasets.
"""


reference_image_paths = {}
ref_path = 'data/'
class_labels = ['nurse','firefighter']

# Prepare reference image paths for each class and gender.
# Keys are (gender, target_class) tuples, e.g. ('male', 'nurse').
for gender in ['male', 'female']:
    for target_class in tqdm(class_labels, desc=f'Processing {gender}'):

        # Reference files are named '<gender>_<class>_<anything>_base.jpg' inside ref_path.
        pattern = os.path.join(ref_path, f'{gender}_{target_class}_*_base.jpg')

        # glob() returns matches in arbitrary filesystem order; sort them so the
        # reference set (and the features computed from it) is reproducible.
        class_specific_img_paths = sorted(glob(pattern))
        if not class_specific_img_paths:
            # An empty reference set would make the guidance meaningless; say so loudly.
            warnings.warn(
                f"No reference images match {pattern!r}; the "
                f"({gender!r}, {target_class!r}) reference set is empty."
            )

        key = (gender, target_class)
        reference_image_paths[key] = class_specific_img_paths

encoder = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50').to(device) # Choose the most reliable image encoder by DQA.

# Path to save/load the reference features
reference_features_path = 'reference_features.pt'

reference_features = {}
if os.path.exists(reference_features_path):
    # map_location lets a cache written on a GPU machine be read on a CPU-only one.
    print("Loading reference features from disk...")
    reference_features = torch.load(reference_features_path, map_location=device)

# Compute features for every (gender, class) key that is not cached yet: all of
# them on a fresh run, or only the new ones after class_labels was extended.
# Without this, a stale cache would make the pipeline fall back to an empty
# reference tensor for the missing key.
missing_keys = [key for key in reference_image_paths if key not in reference_features]
if missing_keys:
    print("Computing reference features...")
    for key in tqdm(missing_keys, desc="Computing Features"):
        reference_features[key] = compute_reference_features(reference_image_paths[key], encoder, device)
    # Save the computed features to disk
    torch.save(reference_features, reference_features_path)
    print("Reference features saved to disk.")

# Move every feature tensor to the working device (no-op for tensors already there).
reference_features = {key: features.to(device) for key, features in reference_features.items()}

"""
DQA-Guidance Term
"""
def compute_combined_objective(fx_t_A, fx_t_B, ref_features_A, ref_features_B, max_D=2000, min_D=50, step_D=50):
    """Compute the DQA guidance objective for two groups of generated features.

    ``DQA = |MMD(A, ref_A) - MMD(B, ref_B)| / (MMD(A+B, ref_A+ref_B) + 1e-8)``,
    where ``A+B`` denotes row-wise concatenation.

    The exact ``compute_mmd`` is tried once. If it raises, ``approximate_mmd``
    is tried with ``D = max_D, max_D - step_D, ...`` down to ``min_D``. All
    three MMD terms (A, B and total) use the same ``D`` in each attempt.

    Args:
        fx_t_A, fx_t_B: features of the currently generated images for groups A and B.
        ref_features_A, ref_features_B: reference features for groups A and B.
        max_D, min_D, step_D: schedule of the ``D`` argument passed to
            ``approximate_mmd`` (``step_D`` must be positive).

    Returns:
        Tuple ``(DQA, MMD_total)``.

    Raises:
        RuntimeError: if the exact MMD and every approximate attempt fail
            (chained from the last approximate error).
    """
    # Ensure all tensors require gradients if needed for backward pass
    fx_t_A = fx_t_A.requires_grad_(True)
    fx_t_B = fx_t_B.requires_grad_(True)

    def _objective(mmd_fn, **mmd_kwargs):
        # Same estimator (and same D, if any) for all three terms so the ratio is consistent.
        MMD_A = mmd_fn(fx_t_A, ref_features_A, **mmd_kwargs)
        MMD_B = mmd_fn(fx_t_B, ref_features_B, **mmd_kwargs)
        MMD_total = mmd_fn(
            torch.cat([fx_t_A, fx_t_B], dim=0),
            torch.cat([ref_features_A, ref_features_B], dim=0),
            **mmd_kwargs,
        )
        N = torch.abs(MMD_A - MMD_B)
        DQA = N / (MMD_total + 1e-8)
        return DQA, MMD_total

    # The exact estimator is attempted only once; retrying it with the same
    # inputs after it failed would just fail again.
    try:
        return _objective(compute_mmd)
    except Exception as exact_error:
        print(f"Exact MMD failed: {exact_error}. Falling back to approximate MMD...")

    last_error = None
    for D in range(max_D, min_D - 1, -step_D):
        try:
            return _objective(approximate_mmd, D=D)
        except Exception as approx_error:
            last_error = approx_error
            print(f"Approximate MMD failed with D={D}: {str(approx_error)}. Reducing D...")

    # Previously fell through and returned None, which broke the caller's tuple unpacking.
    raise RuntimeError(
        f"Exact MMD and approximate MMD (D from {max_D} down to {min_D}) all failed."
    ) from last_error

    
    
"""
DQA-Guidance Stable Diffusion.
It takes two prompts, prompt_A and prompt_B. Scheduler, latent features, graidents are computed separately.
"""
class CustomStableDiffusionPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline that samples two prompts side by side with DQA guidance.

    ``prompt_A`` and ``prompt_B`` are denoised in parallel. Each has its own
    deep copy of the scheduler and its own latents. When ``lambda_reg > 0``,
    both latents are decoded at every step, embedded with an external
    ``feature_extractor``, and scored against per-group reference features by
    ``compute_combined_objective``. The gradient of
    ``lambda_reg * DQA + lambda_2 * MMD_total`` with respect to each branch's
    latents is added to that branch's noise prediction before the scheduler step.
    """

    def preprocess_images(self, images):
        """Resize images to 224x224 and apply mean/std normalisation.

        Args:
            images: numpy array or tensor, either ``[B, C, H, W]`` or
                ``[B, H, W, C]`` (the latter detected by a trailing dimension
                of size 3). Expected to be in ``[0, 1]`` (not checked).

        Returns:
            Float tensor of shape ``[B, C, 224, 224]``.
        """
        # Check if images are in numpy format and convert to tensor if needed
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images).to(device)

        # Ensure images are in float format for proper normalization and interpolation
        images = images.float()
        if images.shape[-1] == 3:  # Likely [B, H, W, C]
            images = images.permute(0, 3, 1, 2)
        images_resized = torch.nn.functional.interpolate(
            images, size=(224, 224), mode='bilinear', align_corners=False
        )

        # Normalize the resized tensor images
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(1, -1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(1, -1, 1, 1)
        images_normalized = (images_resized - mean) / std
        return images_normalized

    def _latents_to_features(self, latents, generator):
        """Decode ``latents`` with the VAE and embed them with ``self.feature_extractor`` (differentiable)."""
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
        # No safety checker runs during guidance, so every image is denormalised.
        image = self.image_processor.postprocess(image, output_type='pt', do_denormalize=[True] * image.shape[0])
        return self.feature_extractor(self.preprocess_images(image).to(device))

    def _dqa_guidance_grads(self, latents_A, latents_B, ref_features_A, ref_features_B, lambda_reg, lambda_2, generator):
        """Return d(loss)/d(latents) for both branches, with loss = lambda_reg * DQA + lambda_2 * MMD_total.

        Gradients are taken with ``torch.autograd.grad`` on fresh leaf copies of
        the latents. This avoids accumulating ``.grad`` buffers on the VAE /
        feature-extractor parameters at every step, and keeps the autograd
        graph out of the scheduler state.
        """
        leaf_A = latents_A.detach().requires_grad_(True)
        leaf_B = latents_B.detach().requires_grad_(True)
        self.feature_extractor.eval()
        with torch.enable_grad():
            f_x_t_A = self._latents_to_features(leaf_A, generator)
            f_x_t_B = self._latents_to_features(leaf_B, generator)

            # Compute DQA metric
            DQA, D_AB = compute_combined_objective(f_x_t_A, f_x_t_B, ref_features_A, ref_features_B)
            loss = lambda_reg * DQA + lambda_2 * D_AB
            grad_A, grad_B = torch.autograd.grad(loss, [leaf_A, leaf_B], allow_unused=True)
        # A branch the loss does not depend on receives zero guidance instead of crashing.
        if grad_A is None:
            grad_A = torch.zeros_like(leaf_A)
        if grad_B is None:
            grad_B = torch.zeros_like(leaf_B)
        return grad_A, grad_B

    def _decode_output(self, latents, output_type, generator):
        """Convert final latents to ``output_type`` without building an autograd graph."""
        with torch.no_grad():
            if output_type == "latent":
                image = latents
            else:
                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
            # No safety checker is run, so every image is denormalised.
            do_denormalize = [True] * image.shape[0]
            return self.image_processor.postprocess(image.detach(), output_type=output_type, do_denormalize=do_denormalize)

    def __call__(
        self,
        prompt_A: Union[str, List[str]] = None,
        prompt_B: Union[str, List[str]] = None,
        key_A: Union[str, List[str]] = None,
        key_B: Union[str, List[str]] = None,
        lambda_reg: float = 7.5,
        lambda_2: float = 0.1,
        feature_extractor=None,
        reference_features=None,
        imagenet_preprocess = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        latents_A: Optional[torch.Tensor] = None,
        latents_B: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_A: Optional[torch.Tensor] = None,
        prompt_embeds_B: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        """Generate images for ``prompt_A`` and ``prompt_B`` jointly, with DQA guidance.

        DQA-specific arguments (the rest mirror ``StableDiffusionPipeline.__call__``):
            prompt_A, prompt_B: prompts of the two groups.
            key_A, key_B: keys into ``reference_features`` selecting each group's reference set.
            lambda_reg: weight of the DQA term; ``0`` disables guidance entirely.
            lambda_2: weight of the joint MMD term.
            feature_extractor: image encoder applied to decoded images; required if ``lambda_reg > 0``.
            reference_features: mapping key -> reference feature tensor; required if ``lambda_reg > 0``.
            imagenet_preprocess: stored on ``self``; not otherwise used here.
            latents_A, latents_B, prompt_embeds_A, prompt_embeds_B: per-branch overrides.

        Returns:
            Tuple ``(output_A, output_B)`` of ``StableDiffusionPipelineOutput``;
            ``nsfw_content_detected`` is always ``None``. ``return_dict`` is ignored.

        Raises:
            ValueError: if ``lambda_reg > 0`` but ``feature_extractor`` or
                ``reference_features`` is missing.
        """
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # Guidance needs an encoder and reference features; fail fast with a clear
        # message instead of an AttributeError deep inside the denoising loop.
        if lambda_reg > 0 and (feature_extractor is None or reference_features is None):
            raise ValueError(
                "DQA guidance (lambda_reg > 0) requires both `feature_extractor` and `reference_features`."
            )
        if feature_extractor is not None:
            self.feature_extractor = feature_extractor.to(device)
        self.reference_features = reference_features
        self.imagenet_preprocess = imagenet_preprocess

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        with torch.no_grad():
            self.check_inputs(
                prompt_A,
                height,
                width,
                callback_steps,
                negative_prompt,
                prompt_embeds_A,
                negative_prompt_embeds,
                ip_adapter_image,
                ip_adapter_image_embeds,
                callback_on_step_end_tensor_inputs,
            )
            self.check_inputs(
                prompt_B,
                height,
                width,
                callback_steps,
                negative_prompt,
                prompt_embeds_B,
                negative_prompt_embeds,
                ip_adapter_image,
                ip_adapter_image_embeds,
                callback_on_step_end_tensor_inputs,
            )
            self.lambda_reg = lambda_reg
            self._guidance_scale = guidance_scale
            self._guidance_rescale = guidance_rescale
            self._clip_skip = clip_skip
            self._cross_attention_kwargs = cross_attention_kwargs
            self._interrupt = False

            # 2. Define call parameters
            if isinstance(prompt_A, str):
                batch_size = 1
            elif isinstance(prompt_A, list):
                batch_size = len(prompt_A)
            else:
                # Branch A was given pre-computed embeddings (previously read the unused `prompt_embeds`).
                batch_size = prompt_embeds_A.shape[0]

            # 3. Encode input prompt
            lora_scale = (
                self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
            )
        with torch.no_grad():
            prompt_embeds_A, negative_prompt_embeds_A = self.encode_prompt(
                prompt_A,
                device,
                num_images_per_prompt,
                self.do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds_A,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=lora_scale,
                clip_skip=self.clip_skip,
            )
            prompt_embeds_B, negative_prompt_embeds_B = self.encode_prompt(
                prompt_B,
                device,
                num_images_per_prompt,
                self.do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds_B,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=lora_scale,
                clip_skip=self.clip_skip,
            )

        # For classifier free guidance, concatenate the unconditional and text
        # embeddings into a single batch to avoid doing two forward passes.
        if self.do_classifier_free_guidance:
            prompt_embeds_A = torch.cat([negative_prompt_embeds_A, prompt_embeds_A])
            prompt_embeds_B = torch.cat([negative_prompt_embeds_B, prompt_embeds_B])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )

        # 5. Prepare latent variables (guidance makes its own leaf copies, so no requires_grad here)
        num_channels_latents = self.unet.config.in_channels
        latents_A = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds_A.dtype,
            device,
            generator,
            latents_A,
        )
        latents_B = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds_B.dtype,
            device,
            generator,
            latents_B,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.1 Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
            else None
        )

        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents_A.dtype)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # Retrieve reference features (only needed when guidance is active).
        ref_features_A = ref_features_B = None
        if self.lambda_reg > 0:
            ref_features_A = self.reference_features.get(key_A, torch.tensor([]).to(device))
            ref_features_B = self.reference_features.get(key_B, torch.tensor([]).to(device))

        # One scheduler copy per branch, so each keeps its own step state.
        # self.scheduler was just configured by retrieve_timesteps and has not been
        # stepped, so the copies already carry the (possibly custom) schedule.
        # Re-calling set_timesteps(num_inference_steps) here would silently discard
        # custom `timesteps` / `sigmas`.
        self.scheduler_A = copy.deepcopy(self.scheduler)
        self.scheduler_B = copy.deepcopy(self.scheduler)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue
                with torch.no_grad():
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input_A = torch.cat([latents_A] * 2) if self.do_classifier_free_guidance else latents_A
                    latent_model_input_A = self.scheduler_A.scale_model_input(latent_model_input_A, t)
                    latent_model_input_B = torch.cat([latents_B] * 2) if self.do_classifier_free_guidance else latents_B
                    latent_model_input_B = self.scheduler_B.scale_model_input(latent_model_input_B, t)
                    noise_pred_A = self.unet(
                        latent_model_input_A,
                        t,
                        encoder_hidden_states=prompt_embeds_A,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred_B = self.unet(
                        latent_model_input_B,
                        t,
                        encoder_hidden_states=prompt_embeds_B,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond_A, noise_pred_text_A = noise_pred_A.chunk(2)
                        noise_pred_uncond_B, noise_pred_text_B = noise_pred_B.chunk(2)
                        noise_pred_A = noise_pred_uncond_A + self.guidance_scale * (noise_pred_text_A - noise_pred_uncond_A)
                        noise_pred_B = noise_pred_uncond_B + self.guidance_scale * (noise_pred_text_B - noise_pred_uncond_B)

                    if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred_A = rescale_noise_cfg(noise_pred_A, noise_pred_text_A, guidance_rescale=self.guidance_rescale)
                        noise_pred_B = rescale_noise_cfg(noise_pred_B, noise_pred_text_B, guidance_rescale=self.guidance_rescale)

                # DQA guidance: add d(loss)/d(latents) to each branch's noise prediction.
                if self.lambda_reg > 0:
                    grad_A, grad_B = self._dqa_guidance_grads(
                        latents_A, latents_B, ref_features_A, ref_features_B, lambda_reg, lambda_2, generator
                    )
                    noise_pred_A = noise_pred_A + grad_A
                    noise_pred_B = noise_pred_B + grad_B

                # compute the previous noisy sample x_t -> x_t-1 (no autograd graph needed)
                with torch.no_grad():
                    latents_A = self.scheduler_A.step(noise_pred_A, t, latents_A, **extra_step_kwargs, return_dict=False)[0]
                    latents_B = self.scheduler_B.step(noise_pred_B, t, latents_B, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents_A = callback_outputs.pop("latents_A", latents_A)
                    prompt_embeds_A = callback_outputs.pop("prompt_embeds_A", prompt_embeds_A)
                    latents_B = callback_outputs.pop("latents_B", latents_B)
                    prompt_embeds_B = callback_outputs.pop("prompt_embeds_B", prompt_embeds_B)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        # The legacy callback takes a single latents tensor; branch A is passed
                        # (previously the unused, always-None `latents` argument was passed).
                        callback(step_idx, t, latents_A)

                if XLA_AVAILABLE:
                    xm.mark_step()

        image_A = self._decode_output(latents_A, output_type, generator)
        image_B = self._decode_output(latents_B, output_type, generator)
        has_nsfw_concept_A = has_nsfw_concept_B = None  # no safety checker is run

        # Offload all models
        self.maybe_free_model_hooks()

        return (
            StableDiffusionPipelineOutput(images=image_A, nsfw_content_detected=has_nsfw_concept_A),
            StableDiffusionPipelineOutput(images=image_B, nsfw_content_detected=has_nsfw_concept_B),
        )


# Load the model
# model_id = "Nihirc/Prompt2MedImage" # Medical image case. Needs medical images (training set of Chest X-ray)
model_id = "runwayml/stable-diffusion-v1-5" # Human image case.

# Hyperparameters
lambda_values = [20]
lambda2s = [1]

guidance_scales = 7.5 # Default setting for Stable Diffusion
batch_size = 2 # Number of prompts per batch
num_batches = 50  # Batches per (profession, gender): 50 * 2 = 100 images each

# Half precision is poorly supported by CPU kernels; use fp32 on the CPU fallback.
pipe_dtype = torch.float16 if device.type == "cuda" else torch.float32

# The pipeline weights do not depend on the hyperparameters, so load them once
# rather than once per (lambda_2, lambda_reg) combination.
pipe = CustomStableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=pipe_dtype
).to(device)

# Loop over hyperparameters
for lambda_2 in tqdm(lambda2s):
    for lambda_reg in tqdm(lambda_values):
        base_output_dir = f'generation_results/resnet_lambda_{lambda_reg}_lambda2_{lambda_2}_guidance_{guidance_scales}'

        for profession in tqdm(class_labels):
            print(f'Run {profession}')
            prompt_A = f'a photo of a man who works as a {profession}'
            prompt_B = f'a photo of a woman who works as a {profession}'
            prompts_A = [prompt_A] * batch_size
            prompts_B = [prompt_B] * batch_size

            # Set seeds for reproducibility (re-seeded per profession, so each
            # profession's run does not depend on the order of class_labels).
            seed = 0
            random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            np.random.seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False  # was misspelled `benchmarks`, which set nothing
            # NOTE: PYTHONHASHSEED only affects interpreters started after this
            # point (e.g. subprocesses); this process's hash seed is already fixed.
            os.environ['PYTHONHASHSEED'] = str(seed)

            # Loop over batches to generate images
            for batch_idx in tqdm(range(num_batches)):
                images_A, images_B = pipe(
                    prompts_A,  # Male prompts
                    prompts_B,  # Female prompts
                    key_A=('male', profession),
                    key_B=('female', profession),
                    lambda_reg=lambda_reg,
                    lambda_2=lambda_2,
                    guidance_scale=guidance_scales,
                    feature_extractor=encoder,
                    reference_features=reference_features,
                    imagenet_preprocess=preprocess,
                )

                # Save images
                save_images(images_A.images, base_output_dir, f'lambda_{lambda_reg}_lambda2_{lambda_2}_male_{profession}', batch_idx)
                save_images(images_B.images, base_output_dir, f'lambda_{lambda_reg}_lambda2_{lambda_2}_female_{profession}', batch_idx)